[TRTLLM-9771][feat] Support partial update weight for fp8 #10456
Conversation
/bot run --disable-fail-fast
PR_Github #30902 [ run ] triggered by Bot. Commit:
PR_Github #30902 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #30988 [ run ] triggered by Bot. Commit:
PR_Github #30988 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #31066 [ run ] triggered by Bot. Commit:
PR_Github #31066 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #31135 [ run ] triggered by Bot. Commit:
PR_Github #31135 [ run ] completed with state
1e5746f to c37c80d (Compare)
📝 Walkthrough
This pull request introduces a weight-loading lifecycle with pre-reload and post-load processing hooks across MoE and linear modules. It adds partial-loading support with dynamic parameter inspection and implements temporary scale storage for weight reconstruction during quantized loading.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Module as MoE Module
participant Interface as MoE Interface
participant QuantMethod as Quantization Method
Module->>Interface: pre_reload_weights()
Interface->>QuantMethod: pre_reload_weights(module)
QuantMethod-->>Module: Cleanup/Recreate tensors
Module->>Interface: load_weights(weights, allow_partial_loading=False)
Interface->>QuantMethod: load_weights(module, weights, ..., allow_partial_loading=False)
QuantMethod->>QuantMethod: Load weights & temporary scales
QuantMethod-->>Module: Weights loaded
Note over Module: If allow_partial_loading=False
Module->>Interface: process_weights_after_loading()
Interface->>QuantMethod: process_weights_after_loading(module)
QuantMethod->>QuantMethod: Finalize scales, recompute metadata
QuantMethod->>Module: replace_parameter_and_save_metadata()
QuantMethod-->>Module: Post-processing complete
sequenceDiagram
participant Test as Test Code
participant TempDir as Temporary Directory
participant RefModel as RefHFModel
participant LLM as LLM Engine
Test->>TempDir: Create temporary model dir
Test->>TempDir: Copy model files, adjust config
Test->>RefModel: Initialize with temp model dir
Test->>LLM: Initialize with same model dir
Test->>RefModel: pad_data(prompts, responses)
RefModel-->>Test: Padded input_ids, attention_mask
Test->>RefModel: generate_batch_with_padding(...)
RefModel->>RefModel: Micro-batch inference
RefModel-->>Test: Reference logits
Test->>LLM: generate(prompts)
LLM-->>Test: LLM logits
Test->>Test: Compare logits for accuracy
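A minimal sketch of how these lifecycle hooks are expected to compose from the caller's side (the driver function below is illustrative and not part of the PR; the hook names come from the walkthrough above):

```python
def reload_weights(module, weight_shards, allow_partial_loading=True):
    """Illustrative driver for the pre-reload / load / post-load lifecycle."""
    # Recreate any tensors that were replaced after the previous load.
    module.pre_reload_weights()
    # Apply one or more (possibly partial) weight updates.
    for shard in weight_shards:
        module.load_weights(shard, allow_partial_loading=allow_partial_loading)
    # Finalize scales and recompute metadata exactly once, after all shards arrive.
    module.process_weights_after_loading()
```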
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing touches
Actionable comments posted: 15
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/utils.py (1)
1-15: Add the NVIDIA copyright header (file was modified). This file doesn't currently include the required NVIDIA header despite being updated in this PR. Please add the appropriate header with the latest modification year.
🤖 Fix all issues with AI agents
In @tensorrt_llm/_torch/modules/fused_moe/quantization.py:
- Around line 491-503: The shape-check logic around w1_weight_shard_viewed vs
dst_w3_w1_weight/dst_w1_weight is fragile because dtype-viewing can hide the
original shape; update the ValueError to include both source and destination
shapes and dtypes (e.g., include w1_weight_shard.shape,
w1_weight_shard_viewed.shape, dst_w3_w1_weight.shape, dst_w1_weight.shape and
their .dtype) and/or perform the comparison using the pre-view shape
(w1_weight_shard.shape) before reinterpreting dtype; ensure the same pattern is
applied for w3_weight_shard checks (referencing dst_w3_weight, dst_w1_weight,
w3_weight_shard and any viewed tensors) so mismatches surface clear diagnostic
info.
- Around line 707-754: The tmp scale buffers (module.tmp_w3_w1_weight_scale,
module.tmp_w2_weight_scale, module.tmp_fc31_input_scale,
module.tmp_fc2_input_scale) must be initialized to safe values (e.g., zeros or
-inf for max-reduction) instead of torch.empty(), and we must record per-slot
whether a scale was loaded so downstream max()/aggregation in
process_weights_after_loading() is not called on uninitialized memory; update
the allocation in load_activation_scales_fp8_qdq / the block creating tmp_* to
use a deterministic fill (torch.zeros or torch.full(..., -float('inf')) as
appropriate) and add a boolean mask or loaded_count per expert that
load_expert_w3_w1_weight_scale_fp8_qdq and load_expert_w2_weight_scale_fp8 set
when they actually load a scale, then in process_weights_after_loading() check
that mask/count for each slot and skip aggregation or apply a safe fallback when
no scales were loaded for that slot.
- Around line 565-570: Replace placeholder f-strings like
f"gate_up_proj_input_scale" and f"down_proj_input_scale" with plain string
literals; in load_activation_scales_fp8_qdq() stop storing temporary tensors as
long-lived module attributes
(module.tmp_fc31_input_scale/module.tmp_fc2_input_scale) or ensure they are
removed immediately after they are copied into
dst_fc31_input_scale/dst_fc2_input_scale—e.g., copy the values then
delattr(module, "tmp_fc31_input_scale") / delattr(module, "tmp_fc2_input_scale")
or set them to None so partial-loading doesn't leak buffers; update
process_weights_after_loading() behavior to be consistent but do not rely on it
for cleanup when partial-loading is enabled; finally add the required NVIDIA
copyright header at the top of the file.
- Around line 880-927: The FUSED_GATE_UP_PROJ branch calls
load_weight_shard(w3_scale, ...) without checking that w3_scale is not None,
which will raise when "gate_up_proj_weight_scale" is absent; mirror the VANILLA
branch behavior by adding a guard: only call load_weight_shard for w3_scale and
copy into dst_w3_w1_weight_scale[local_slot_id] if w3_scale is not None (similar
to the existing if w1_scale is not None / if w3_scale is not None checks),
ensuring w3_scale is validated before invoking load_weight_shard and copy_. Use
the same load_weight_shard parameters and TensorParallelMode.COLUMN as in the
current FUSED_GATE_UP_PROJ block.
In @tensorrt_llm/_torch/modules/linear.py:
- Around line 635-640: In process_weights_after_loading_vanilla, avoid
unconditionally calling delattr(module, "has_static_input_scale") which can
raise AttributeError when the attribute was never set; change to remove the
attribute only if present (e.g., check hasattr(module, "has_static_input_scale")
before deleting or wrap delattr in try/except AttributeError) while preserving
the existing logic that clears module.input_scale and module.inv_input_scale
when has_static_input_scale is False.
- Around line 1010-1013: The offset-to-block mapping currently uses math.ceil(offset / divisible_factor), which is incorrect for coordinates; use floor division (offset // divisible_factor) when mapping an offset to its scale block, so positions 0-127 map to block 0 and 128-255 to block 1, while keeping math.ceil for the size/span calculation. Update any function or method that computes a block index from an `offset` (and any place referencing `divisible_factor` or `size`) to use `offset // divisible_factor` for offsets and retain `math.ceil(size / divisible_factor)` for sizes (a short sketch of this distinction follows this list).
In @tensorrt_llm/llmapi/rlhf_utils.py:
- Line 107: The code unconditionally calls
delattr(self.engine.model_engine.model, "first_pre_reload_weights") which can
raise AttributeError if that attribute was never set (e.g., update_weights
called with ipc_handles=None on the finalize path); update the logic in the
update_weights method to either check hasattr(self.engine.model_engine.model,
"first_pre_reload_weights") before calling delattr, or wrap the delattr call in
a try/except AttributeError and ignore the exception, referencing the attribute
name "first_pre_reload_weights" and the object self.engine.model_engine.model to
locate the change.
In @tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py:
- Around line 20-44: Refactor RefHFModelWithIPCHandles._replicate_weights to
avoid copying model weights to every CUDA device; instead only store the cloned
weights for self.device_id in self.all_weights and do not iterate over
torch.cuda.device_count(); add (or update) a lazy materialization path that
converts weights to another device only when requested (e.g., from
get_weight_ipc_handles or a new method like materialize_weights_on(device_idx)),
so other devices are created on-demand and prevent OOM on CI.
- Around line 121-147: In process_and_copy_folder, rename the unused loop
variable dirs to _dirs to satisfy Ruff B007, and replace the explicit existence
check plus os.makedirs for dest_dir with a single os.makedirs(dest_dir,
exist_ok=True) call so directory creation is idempotent; update the for root,
_dirs, files in os.walk(...) and change the dest_dir handling to use
os.makedirs(..., exist_ok=True) before writing or copying files.
In @tests/unittest/utils/torch_ref.py:
- Around line 1343-1427: The pad_data static method contains dead code and a
Python-version-incompatible pattern: remove the unused response_lens list
(created at response_lens = [] and appended to inside the loop) and eliminate
any reliance on zip(..., strict=...) by keeping the current explicit pairing
approach (use range(len(original_prompts)) or manually check lengths) so
pad_data (the function) only iterates with indices to validate and populate
batch tensors; ensure no references to response_lens remain and that input
validation compares lengths of original_prompts and generated_token_ids_list
explicitly before building tensors.
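To make the floor-vs-ceil distinction from the offset-to-block note above concrete, here is a minimal, self-contained sketch (the helper names and the 128-element block size are illustrative assumptions, not code from the PR):

```python
import math

BLOCK = 128  # assumed scale-block granularity (the review's divisible_factor)

def block_index(offset: int, block: int = BLOCK) -> int:
    # Coordinates map to blocks with floor division: 0-127 -> block 0, 128-255 -> block 1.
    return offset // block

def blocks_spanned(size: int, block: int = BLOCK) -> int:
    # Sizes still round up: a 129-element span touches two blocks.
    return math.ceil(size / block)

assert block_index(127) == 0 and block_index(128) == 1
assert blocks_spanned(128) == 1 and blocks_spanned(129) == 2
```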
🧹 Nitpick comments (13)
tests/unittest/utils/torch_ref.py (1)
16-23: Import style: keep module namespaces (per guidelines). New imports like `from transformers import AutoModelForCausalLM` violate the "maintain namespace" rule in the provided guidelines. Consider `import transformers` and then `transformers.AutoModelForCausalLM`.
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py (2)
18-26: Import style: keep module namespaces (per guidelines). New imports like `from transformers import AutoTokenizer` violate the "maintain namespace" rule in the provided guidelines.
31-73: `asyncio.get_event_loop()` inside a coroutine context is brittle on newer Python. Inside `generate_batch_async`, prefer `asyncio.get_running_loop()`; `get_event_loop()` behavior changed across Python versions and can break under some runners. Proposed patch:
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_running_loop()
tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py (1)
1-17: Import style: keep module namespaces (per guidelines). New imports like `from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer` violate the "maintain namespace" rule in the provided guidelines.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)
932-958: Consider caching the introspection result (and fix the file mode if EXE002 is real).
- `inspect.getfullargspec(...)` runs on every `load_weights()` call; you could compute a boolean once when `quant_method` is created (a small sketch of this caching idea follows these nitpick notes).
- Ruff EXE002 suggests the file is executable but has no shebang; if that's due to the git file mode, it's better fixed by changing the mode than by adding a shebang to a library module.
tensorrt_llm/llmapi/rlhf_utils.py (1)
53-59: Consider using direct attribute assignment instead of `setattr`. The static analysis tool flags this as `setattr` with a constant attribute name. Direct assignment is cleaner and equivalent:
- setattr(self.engine.model_engine.model, "first_pre_reload_weights", True)
+ self.engine.model_engine.model.first_pre_reload_weights = True
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1)
866-871: Consider renaming `kargs` to `kwargs` for Python convention. The variable name `kargs` is unconventional; the standard Python convention is `kwargs` for keyword arguments:
- kargs = {}
+ kwargs = {}
  if "allow_partial_loading" in inspect.getfullargspec(
          self.quant_method.load_weights).args:
-     kargs["allow_partial_loading"] = allow_partial_loading
+     kwargs["allow_partial_loading"] = allow_partial_loading
  self.quant_method.load_weights(self, weights, self.weight_loading_mode,
-                                **kargs)
+                                **kwargs)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1)
277-282: Consider renaming `kargs` to `kwargs` for Python convention. Same as in `fused_moe_cutlass.py`, the variable name should follow the standard convention:
- kargs = {}
+ kwargs = {}
  if "allow_partial_loading" in inspect.getfullargspec(
          self.quant_method.load_weights).args:
-     kargs["allow_partial_loading"] = allow_partial_loading
+     kwargs["allow_partial_loading"] = allow_partial_loading
  self.quant_method.load_weights(self, weights, self.weight_loading_mode,
-                                **kargs)
+                                **kwargs)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)
21-22: Import style conflicts with the repo Python guidelines (namespace-preserving imports). `from ...utils import (replace_parameter_and_save_metadata, swizzle_sf, unswizzle_sf)` violates "Always maintain the namespace when importing Python modules…" in the provided coding guidelines; prefer `from ... import utils as torch_utils` (or similar) and call `torch_utils.replace_parameter_and_save_metadata(...)`.
447-449: Make the post-load hook contract explicit (docstring + expected callers). `process_weights_after_loading()` is currently a silent no-op at the base, but subclasses (e.g., FP8QDQ) rely on it for final scale computation and tmp_* cleanup. Consider documenting who calls it (especially in partial-loading flows) and when.
859-866: Redundant override: `DeepSeekFP8BlockScalesFusedMoEMethod.load_weights()` just calls `super()`. Unless you're pinning the signature for readability/API reasons, this can be removed to reduce noise.
939-970: Constant `getattr`/`setattr` usage is unnecessary here; direct attribute access is clearer. These are fixed attribute names; using `if hasattr(...): ... else: module.local_shared_* = ...` improves readability (and satisfies Ruff B009/B010). Also applies to: 971-985
tensorrt_llm/_torch/modules/linear.py (1)
602-619: `load_weight_scales()` typing: avoid implicit Optional, and consider `zip(..., strict=...)` only if Python >= 3.10. If you move to `Optional[List[str]]`, this becomes clearer. Also, `zip(shard_keys, weights)` silently truncates on mismatch; if the project is Python 3.10+ you can add `strict=True`, but that conflicts with the stated 3.8 guideline. Also applies to: 621-634
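A small sketch of the caching idea from the `fused_moe_wide_ep.py` nitpick above (the mixin and the `_supports_partial_loading` attribute name are hypothetical, not code from the PR):

```python
import inspect


class MoELoadWeightsMixin:
    """Illustrative only: inspect the quant method's signature once, not per call."""

    def _attach_quant_method(self, quant_method):
        self.quant_method = quant_method
        # Cache the capability check instead of re-running getfullargspec()
        # on every load_weights() invocation.
        self._supports_partial_loading = (
            "allow_partial_loading"
            in inspect.getfullargspec(quant_method.load_weights).args)

    def load_weights(self, weights, allow_partial_loading=False):
        kwargs = {}
        if self._supports_partial_loading:
            kwargs["allow_partial_loading"] = allow_partial_loading
        # Extra positional arguments (e.g., weight_loading_mode) omitted for brevity.
        self.quant_method.load_weights(self, weights, **kwargs)
```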
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/utils.py
- tensorrt_llm/llmapi/rlhf_utils.py
- tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py
- tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
- tests/unittest/utils/torch_ref.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., `some_file.py`)
Python classes should use PascalCase (e.g., `class SomeClass`)
Python functions and methods should use snake_case (e.g., `def my_awesome_function():`)
Python local variables should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format `"""<type>: Description"""`
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic
Files:
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_torch/utils.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
- tensorrt_llm/llmapi/rlhf_utils.py
- tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
- tests/unittest/utils/torch_ref.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tensorrt_llm/_torch/modules/linear.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification
Files:
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_torch/utils.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
- tensorrt_llm/llmapi/rlhf_utils.py
- tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
- tests/unittest/utils/torch_ref.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tensorrt_llm/_torch/modules/linear.py
🧠 Learnings (6)
📚 Learning: 2025-08-14T23:23:27.449Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Applied to files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
📚 Learning: 2025-09-19T21:28:13.751Z
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.
Applied to files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
📚 Learning: 2025-08-21T21:48:35.135Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:399-417
Timestamp: 2025-08-21T21:48:35.135Z
Learning: CUTLASS extensions in TensorRT-LLM (located under cpp/tensorrt_llm/cutlass_extensions/) are designed to integrate with and extend functionality in the external CUTLASS repository. When analyzing these extensions, their consumers and functionality wiring may exist in the CUTLASS codebase rather than within TensorRT-LLM itself.
Applied to files:
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
📚 Learning: 2025-09-16T09:30:09.716Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7763
File: cpp/tensorrt_llm/CMakeLists.txt:297-301
Timestamp: 2025-09-16T09:30:09.716Z
Learning: In the TensorRT-LLM project, NCCL libraries are loaded earlier by PyTorch libraries or the bindings library, so the main shared library doesn't need NCCL paths in its RPATH - the libraries will already be available in the process address space when needed.
Applied to files:
tensorrt_llm/llmapi/rlhf_utils.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
📚 Learning: 2025-09-03T13:16:06.824Z
Learnt from: nvpohanh
Repo: NVIDIA/TensorRT-LLM PR: 7478
File: tensorrt_llm/_torch/models/modeling_llama.py:1315-1315
Timestamp: 2025-09-03T13:16:06.824Z
Learning: The Llama4VisionEncoder.load_weights method signature is `def load_weights(self, weights: Dict)` and should not be confused with Llama4ForConditionalGeneration.load_weights which has a different signature including weight_mapper parameter.
Applied to files:
tensorrt_llm/_torch/modules/fused_moe/interface.py
🧬 Code graph analysis (8)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (4)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (4): load_weights (859-871), post_load_weights (873-874), process_weights_after_loading (876-878), pre_reload_weights (880-884)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (4): load_weights (1392-1400), post_load_weights (1402-1403), process_weights_after_loading (1405-1407), pre_reload_weights (1409-1413)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (4): load_weights (932-944), post_load_weights (946-947), process_weights_after_loading (949-951), pre_reload_weights (953-957)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (4): load_weights (524-527), post_load_weights (532-533), process_weights_after_loading (529-530), pre_reload_weights (535-536)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (4): load_weights (859-871), post_load_weights (873-874), process_weights_after_loading (876-878), pre_reload_weights (880-884)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (4): load_weights (1392-1400), post_load_weights (1402-1403), process_weights_after_loading (1405-1407), pre_reload_weights (1409-1413)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (4): load_weights (524-527), post_load_weights (532-533), process_weights_after_loading (529-530), pre_reload_weights (535-536)
tensorrt_llm/llmapi/rlhf_utils.py (9)
- tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py (1): model (169-172)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (3): pre_reload_weights (532-540), process_weights_after_loading (447-448), process_weights_after_loading (755-783)
- tensorrt_llm/_torch/modules/linear.py (4): pre_reload_weights (513-520), pre_reload_weights (2573-2577), process_weights_after_loading (408-420), process_weights_after_loading (2567-2568)
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py (2): pre_reload_weights (1137-1145), process_weights_after_loading (1127-1135)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (2): pre_reload_weights (880-884), process_weights_after_loading (876-878)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (2): pre_reload_weights (1409-1413), process_weights_after_loading (1405-1407)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (2): pre_reload_weights (291-295), process_weights_after_loading (287-289)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2): pre_reload_weights (953-957), process_weights_after_loading (949-951)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (2): pre_reload_weights (535-536), process_weights_after_loading (529-530)
tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py (2)
- tests/unittest/utils/torch_ref.py (3): RefHFModel (1245-1427), pad_data (1344-1427), generate_batch_with_padding (1255-1341)
- tensorrt_llm/_torch/utils.py (2): get_device_uuid (362-367), _ (228-234)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (6)
- tensorrt_llm/_torch/modules/linear.py (4): process_weights_after_loading (408-420), process_weights_after_loading (2567-2568), pre_reload_weights (513-520), pre_reload_weights (2573-2577)
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py (3): process_weights_after_loading (1127-1135), quant_method (1176-1178), pre_reload_weights (1137-1145)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (2): process_weights_after_loading (876-878), pre_reload_weights (880-884)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (2): process_weights_after_loading (287-289), pre_reload_weights (291-295)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2): process_weights_after_loading (949-951), pre_reload_weights (953-957)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (2): process_weights_after_loading (529-530), pre_reload_weights (535-536)
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py (1)
- tests/unittest/utils/torch_ref.py (2): RefHFModel (1245-1427), pad_data (1344-1427)
tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py (7)
- tensorrt_llm/_torch/modules/fused_moe/quantization.py (3): process_weights_after_loading (447-448), process_weights_after_loading (755-783), pre_reload_weights (532-540)
- tensorrt_llm/_torch/modules/linear.py (4): process_weights_after_loading (408-420), process_weights_after_loading (2567-2568), pre_reload_weights (513-520), pre_reload_weights (2573-2577)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (2): process_weights_after_loading (876-878), pre_reload_weights (880-884)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (2): process_weights_after_loading (1405-1407), pre_reload_weights (1409-1413)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (2): process_weights_after_loading (287-289), pre_reload_weights (291-295)
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2): process_weights_after_loading (949-951), pre_reload_weights (953-957)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (2): process_weights_after_loading (529-530), pre_reload_weights (535-536)
tests/unittest/utils/torch_ref.py (1)
- tensorrt_llm/_torch/models/modeling_auto.py (1): AutoModelForCausalLM (10-52)
🪛 Ruff (0.14.10)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
1-1: The file is executable but no shebang is present
(EXE002)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
1-1: The file is executable but no shebang is present
(EXE002)
tensorrt_llm/llmapi/rlhf_utils.py
59-59: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
126-126: Loop control variable dirs not used within loop body
Rename unused dirs to _dirs
(B007)
tests/unittest/utils/torch_ref.py
1373-1373: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tensorrt_llm/_torch/modules/fused_moe/quantization.py
447-448: FusedMoEMethodBase.process_weights_after_loading is an empty method in an abstract base class, but has no abstract decorator
(B027)
502-502: Avoid specifying long messages outside the exception class
(TRY003)
601-601: f-string without any placeholders
Remove extraneous f prefix
(F541)
601-601: f-string without any placeholders
Remove extraneous f prefix
(F541)
603-603: f-string without any placeholders
Remove extraneous f prefix
(F541)
603-603: f-string without any placeholders
Remove extraneous f prefix
(F541)
605-605: f-string without any placeholders
Remove extraneous f prefix
(F541)
605-605: f-string without any placeholders
Remove extraneous f prefix
(F541)
734-734: f-string without any placeholders
Remove extraneous f prefix
(F541)
734-734: f-string without any placeholders
Remove extraneous f prefix
(F541)
736-736: f-string without any placeholders
Remove extraneous f prefix
(F541)
736-736: f-string without any placeholders
Remove extraneous f prefix
(F541)
738-738: f-string without any placeholders
Remove extraneous f prefix
(F541)
738-738: f-string without any placeholders
Remove extraneous f prefix
(F541)
941-942: Do not call getattr with a constant attribute value. It is not any safer than normal property access.
Replace getattr with attribute access
(B009)
945-946: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
949-950: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
953-954: Do not call getattr with a constant attribute value. It is not any safer than normal property access.
Replace getattr with attribute access
(B009)
957-958: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
961-962: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
975-976: Do not call getattr with a constant attribute value. It is not any safer than normal property access.
Replace getattr with attribute access
(B009)
979-980: Do not call getattr with a constant attribute value. It is not any safer than normal property access.
Replace getattr with attribute access
(B009)
tensorrt_llm/_torch/modules/linear.py
420-420: Avoid specifying long messages outside the exception class
(TRY003)
422-425: LinearMethodBase.process_weights_after_loading_vanilla is an empty method in an abstract base class, but has no abstract decorator
(B027)
427-430: LinearMethodBase.process_weights_after_loading_fused_qkv_linear is an empty method in an abstract base class, but has no abstract decorator
(B027)
432-436: LinearMethodBase.process_weights_after_loading_fused_gate_up_linear is an empty method in an abstract base class, but has no abstract decorator
(B027)
604-604: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
613-613: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
631-631: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
675-675: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
805-805: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
910-912: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
933-935: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
1078-1078: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
1106-1106: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (18)
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_accuracy_with_allreduce_strategy.py (1)
166-234: LGTM: test now consistently uses `RefHFModel.pad_data()` + per-GPU `RefHFModel` instances. The refactor reads cleanly and the per-device inference split is easier to follow than the prior inlined utilities.
tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py (1)
102-118: Potential mismatch risk: verify `generation_logits` aligns with the HF slicing contract. `run_generate()` compares TRT-LLM `generation_logits` with HF logits sliced from `[prompt_max_len-1 : prompt_max_len-1+response_len]`. This assumes that `output.outputs[0].token_ids` are generated tokens only (no prompt tokens), and that `generation_logits` correspond to next-token predictions for exactly those generated tokens. Please double-check this contract in the LLM output object; otherwise the comparison can be off-by-prompt or off-by-one.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py (1)
1405-1413: LGTM: lifecycle hooks mirror other MoE backends. `process_weights_after_loading()` and `pre_reload_weights()` match the patterns in Cutlass/TRTLLMGen/WideEP and keep the API consistent.
tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py (1)
1127-1146: LGTM: wrapper hooks enforce backend capabilities clearly. The assertions provide a clean failure mode if a backend doesn't implement the new lifecycle hooks.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)
1-1: LGTM: using introspection to pass `allow_partial_loading` only when supported. The dynamic kwarg forwarding keeps backward compatibility with quant methods that haven't added the parameter.
tensorrt_llm/llmapi/rlhf_utils.py (1)
92-99: LGTM! The `process_weights_after_loading` hook is correctly integrated with proper `hasattr` and `_weights_removed` checks, consistent with the existing `post_load_weights` pattern.
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
524-537: LGTM! The interface additions are well-designed:
- The `allow_partial_loading` parameter has a default value, maintaining backward compatibility
- New lifecycle hooks (`process_weights_after_loading`, `pre_reload_weights`) have empty default implementations, allowing subclasses to opt in
- Consistent with the existing `post_load_weights` pattern
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (2)
876-884: LGTM! The new lifecycle hooks are correctly implemented:
- `process_weights_after_loading` defensively checks for the method before calling (optional hook)
- `pre_reload_weights` uses an assertion to fail fast if the quant method doesn't support it (required when called)
This pattern is consistent with the other fused MoE backends.
1-1: LGTM! The `inspect` import is correctly added for the runtime introspection of the `quant_method.load_weights` signature. The static analysis hint about a missing shebang is a false positive for a Python module.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (2)
287-295: LGTM! The lifecycle hooks are correctly implemented with the same pattern as `fused_moe_cutlass.py` and other backends:
- `process_weights_after_loading` with optional delegation via `hasattr`
- `pre_reload_weights` with required delegation via assertion
1-1: LGTM! The `inspect` import is correctly added for runtime signature introspection.
tensorrt_llm/_torch/modules/fused_moe/quantization.py (3)
240-241: `module.rebuild_tensor_metadata` initialization is good, but ensure it exists for all reload paths. This enables later tensor-replacement bookkeeping; just make sure any code path that calls `pre_reload_weights()` can't run before `create_weights()` has set this attribute.
810-819: Nice: metadata-tracked parameter replacement for padding. Using `replace_parameter_and_save_metadata()` here is the right direction for later reload / CUDA-graph warnings.
1009-1013: DeepGemm post-load parameter replacement is good; confirm it matches reload semantics. Since you replace parameters and record metadata, ensure the pre-reload logic recreates tensors with correct dtype/shape/device for both `w3_w1_weight_scaling_factor` and `w2_weight_scaling_factor`. Also applies to: 1024-1042
tensorrt_llm/_torch/modules/linear.py (4)
365-367: Post-load processing hook wiring looks good; ensure callers know they must invoke it after partial loads. The new `process_weights_after_loading*` dispatch is a clean place to centralize "finalize" logic. The main gap is ensuring the partial-loading workflow actually calls `Linear.process_weights_after_loading()` once all shards are updated. Also applies to: 408-437
878-941: Use explicit tuple iteration for scale assignment, for consistency with the rest of the codebase. The scale-loading code uses `zip(module.fused_weight_shard_indices_mapping.keys(), [...])`, but identical patterns elsewhere in this file use explicit tuples: `zip(('q', 'k', 'v'), ...)`. For clarity and consistency, use the explicit tuple pattern here too.
31-32: Review comment is incorrect: the actual Python version requirement is 3.10+, not 3.8+. The repository's `setup.py` explicitly requires `python_requires=">=3.10, <4"` (line 354). The code's use of PEP 604 unions (`|`) and new-style type hints (`list[str]`, `dict[str, int]`) at lines 53 and 62 is therefore valid and appropriate for the target Python version. While the CODING_GUIDELINES reference Python 3.8+, the actual binding constraint is Python 3.10+, making the current type annotations correct. Likely an incorrect or invalid review comment.
453-454: Replace hard-coded `device="cuda"` to support CPU and multi-GPU placement. The metadata is stored via `.to("meta")`, which loses original device information. Using `torch.empty_like(metadata, device="cuda")` hardcodes CUDA and breaks CPU inference and multi-device scenarios. Store the original device alongside the metadata or infer it from other module parameters that haven't been replaced yet. Also applies to: 513-521, 550-551, 2567-2569, 2573-2577, and tensorrt_llm/_torch/modules/fused_moe/quantization.py
⛔ Skipped due to learnings
Learnt from: yibinl-nvidia Repo: NVIDIA/TensorRT-LLM PR: 8923 File: tests/unittest/_torch/modeling/test_modeling_starcoder2.py:157-160 Timestamp: 2025-11-13T18:53:29.113Z Learning: In PyTorch 2.0+, torch.device() can be used as a context manager (e.g., `with torch.device('cuda'):`) to set the default device for tensor creation operations within that context. This is a valid pattern and should not be flagged as an error.Learnt from: nzmora-nvidia Repo: NVIDIA/TensorRT-LLM PR: 9163 File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113 Timestamp: 2025-11-14T11:22:03.729Z Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
Code excerpt under review:
    if allow_partial_loading:
        assert isinstance(
            self, (UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod,
                   DeepSeekFP8BlockScalesFusedMoEMethod,
                   DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm)
        ), "Partial loading is only supported for unquantized and FP8 models"
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM):
- find . -type f -name "quantization.py" -path "*fused_moe*" (output length: 118)
- wc -l ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 123)
- sed -n '330,350p' ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 1073)
- sed -n '410,420p' ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 552)
- rg "assert isinstance" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -n (output length: 97)
- rg "assert " ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -n | head -20 (output length: 1351)
- rg "class FP8QDQFusedMoEMethod" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -A 30 (output length: 1465)
- rg "def process_weights_after_loading" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 2 -A 25 (output length: 2660)
- rg "delattr.*tmp_" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 5 -A 2 (output length: 464)
- rg "def post_load_weights" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 2 -A 20 | grep -A 20 "class FP8QDQFusedMoEMethod" | head -40 (output length: 45)
- sed -n '1900,1950p' ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 2478)
- rg "class FP8QDQFusedMoEMethod" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -A 200 | grep -A 10 "def post_load_weights" (output length: 660)
- rg "process_weights_after_loading" ./tensorrt_llm/_torch/modules/fused_moe/quantization.py -n (output length: 269)
- sed -n '447,490p' ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 1832)
- sed -n '755,810p' ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 3057)
- sed -n '330,360p' ./tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 1551)
Don't use assert for runtime feature gating; raise a real exception (asserts can be stripped).
If partial loading is unsupported for a method, this should be a deterministic runtime error (e.g., ValueError / NotImplementedError), not an assert.
Also, when allow_partial_loading=True, the code skips process_weights_after_loading() (line 414-415). For FP8QDQFusedMoEMethod, this is a critical issue: the temporary buffers (tmp_w3_w1_weight_scale, tmp_w2_weight_scale, etc.) won't be cleaned up via delattr() and final scales (fc31_dequant, fc2_quant, etc.) won't be computed. An explicit finalization must occur elsewhere, or this becomes a resource/correctness leak.
Proposed fix for assert
- if allow_partial_loading:
- assert isinstance(
- self, (UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod,
- DeepSeekFP8BlockScalesFusedMoEMethod,
- DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm)
- ), "Partial loading is only supported for unquantized and FP8 models"
+ if allow_partial_loading and not isinstance(
+ self,
+ (
+ UnquantizedFusedMoEMethod,
+ FP8QDQFusedMoEMethod,
+ DeepSeekFP8BlockScalesFusedMoEMethod,
+ DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
+ ),
+ ):
+ raise NotImplementedError(
+ "Partial loading is only supported for unquantized and FP8 models"
+     )
Code excerpt under review:
    dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0)
    if w1_weight_shard is not None and w1_weight_shard.shape[0] != 0:
        w1_weight_shard_viewed = w1_weight_shard.contiguous().view(
            dst_w3_w1_weight.dtype)
        if w1_weight_shard_viewed.shape[0] == dst_w3_w1_weight.shape[0]:
            # w3_weight (gate_proj) should be empty for Nemotron-H MoE model.
            dst_w3_w1_weight.copy_(w1_weight_shard_viewed,
                                   non_blocking=True)
        elif w1_weight_shard_viewed.shape[0] == dst_w1_weight.shape[0]:
            dst_w1_weight.copy_(w1_weight_shard_viewed, non_blocking=True)
        else:
            raise ValueError("Shape mismatch!")
    if w3_weight_shard is not None and w3_weight_shard.shape[0] != 0:
Partial w1/w3 load logic: shape checks look fragile around dtype “viewing”; improve error detail.
Because you compare w1_weight_shard_viewed.shape[0] against destination shapes after a dtype reinterpretation, mismatches can be confusing to debug. Consider including source/dst shapes + dtypes in the exception message (and/or comparing pre-view shapes if that’s the real intent).
🧰 Tools
🪛 Ruff (0.14.10)
502-502: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/modules/fused_moe/quantization.py around lines 491 -
503, The shape-check logic around w1_weight_shard_viewed vs
dst_w3_w1_weight/dst_w1_weight is fragile because dtype-viewing can hide the
original shape; update the ValueError to include both source and destination
shapes and dtypes (e.g., include w1_weight_shard.shape,
w1_weight_shard_viewed.shape, dst_w3_w1_weight.shape, dst_w1_weight.shape and
their .dtype) and/or perform the comparison using the pre-view shape
(w1_weight_shard.shape) before reinterpreting dtype; ensure the same pattern is
applied for w3_weight_shard checks (referencing dst_w3_weight, dst_w1_weight,
w3_weight_shard and any viewed tensors) so mismatches surface clear diagnostic
info.
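A sketch of the kind of diagnostic the reviewer is asking for; the helper below is illustrative (not code from the PR) and takes the shard, its dtype-viewed form, and both candidate destinations so the error message carries all shapes and dtypes:

```python
import torch


def check_shard_shapes(shard: torch.Tensor, viewed: torch.Tensor,
                       dst_fused: torch.Tensor, dst_single: torch.Tensor) -> None:
    """Raise a descriptive error when a dtype-viewed shard fits neither destination."""
    if viewed.shape[0] in (dst_fused.shape[0], dst_single.shape[0]):
        return
    raise ValueError(
        "Shape mismatch while loading expert weight shard: "
        f"src {tuple(shard.shape)}/{shard.dtype} viewed as "
        f"{tuple(viewed.shape)}/{viewed.dtype}; destinations are "
        f"{tuple(dst_fused.shape)}/{dst_fused.dtype} (w3_w1) and "
        f"{tuple(dst_single.shape)}/{dst_single.dtype} (w1)")
```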
Code excerpt under review:
    def pre_reload_weights(self, module: torch.nn.Module):
        for param_name, metadata in module.rebuild_tensor_metadata.items():
            logger.warning(
                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
            )
            param = torch.nn.Parameter(torch.empty_like(metadata,
                                                        device="cuda"),
                                       requires_grad=False)
            module.register_parameter(param_name, param)
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM):
- fd -t f "quantization.py" tensorrt_llm/_torch/modules/ (output length: 182)
- wc -l tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 182)
- sed -n '520,560p' tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 182)
- rg -n "rebuild_tensor_metadata" tensorrt_llm/_torch/ -A 2 -B 2 (output length: 5871)
- rg -n "def replace_parameter_and_save_metadata" tensorrt_llm/_torch/ -A 10 (output length: 916)
- sed -n '530,550p' tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 1040)
- sed -n '510,560p' tensorrt_llm/_torch/modules/linear.py (output length: 2663)
- rg -n "pre_reload_weights" tensorrt_llm/ -B 3 -A 15 (output length: 18860)
- rg -n "cpu\|CPU" tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 45)
- sed -n '415,435p' tensorrt_llm/_torch/utils.py (output length: 661)
pre_reload_weights() hard-codes CUDA device; should recreate parameters on the module's device.
The hard-coded device="cuda" breaks CPU inference flows and multi-GPU setups where parameters may be placed on different devices. The metadata is stored on the "meta" device (line 427, utils.py), which doesn't preserve device information. Either store the original device in metadata when saving the parameter, or infer it from existing module parameters using next(module.parameters()).device as a fallback.
This pattern appears in multiple quantization methods across linear.py, quantization.py, and several MoE backends (wide_ep, trtllm_gen, triton, cutlass).
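A sketch of one of the suggested fixes (inferring the device from an existing parameter rather than hard-coding "cuda"). The function below is a simplified, illustrative stand-in for the `pre_reload_weights` implementations mentioned above, not the PR's code:

```python
import torch


def pre_reload_weights_sketch(module: torch.nn.Module) -> None:
    """Recreate replaced parameters on the module's own device rather than a fixed one."""
    # "meta" metadata carries no device info, so fall back to an existing parameter's device.
    target_device = next(
        (p.device for p in module.parameters() if p.device.type != "meta"),
        torch.device("cpu"))
    for param_name, metadata in getattr(module, "rebuild_tensor_metadata", {}).items():
        param = torch.nn.Parameter(
            torch.empty_like(metadata, device=target_device),
            requires_grad=False)
        module.register_parameter(param_name, param)
```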
Code excerpt under review:
    if w1_input_scale is not None and w1_input_scale.numel() != 0:
        w1_input_scale = w1_input_scale[...].reshape([])
        dst_fc31_input_scale[0].copy_(w1_input_scale)
    if w3_input_scale is not None and w3_input_scale.numel() != 0:
        w3_input_scale = w3_input_scale[...].reshape([])
        dst_fc31_input_scale[1].copy_(w3_input_scale)
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM):
- wc -l tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 182)
- sed -n '565,620p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n (output length: 3250)
- grep -n "load_activation_scales_fp8_qdq" tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 202)
- grep -n "tmp_fc31_input_scale\|tmp_fc2_input_scale" tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 2306)
- sed -n '750,790p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n (output length: 2364)
- grep -B5 "delattr(module, 'tmp_fc31_input_scale')" tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -20 (output length: 249)
- grep -n "process_weights_after_loading\|partial.*loading\|VANILLA\|FUSED_GATE_UP_PROJ" tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -30 (output length: 2090)
- head -20 tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 777)
- head -5 tensorrt_llm/_torch/modules/fused_moe/quantization.py (output length: 185)
- sed -n '1,10p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n (output length: 381)
- sed -n '590,620p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n (output length: 1957)
Fix placeholder-less f-strings and add missing copyright header; also address temporary buffer lifecycle in partial-loading mode.
- Lines 600-604 contain f-strings with no placeholders (`f"gate_up_proj_input_scale"`, `f"down_proj_input_scale"`) that should be plain strings (Ruff F541).
- `load_activation_scales_fp8_qdq()` allocates `module.tmp_fc31_input_scale` and `module.tmp_fc2_input_scale` (lines 579–590), but these are only deleted in `process_weights_after_loading()` when `allow_partial_loading=False`. When partial loading is enabled, the temporary buffers persist on the module until external finalization.
- File is missing the required NVIDIA copyright header.
Mechanical f-string fix
- w1_input_scale = weights[
- f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None
- w3_input_scale = weights[
- f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None
- w2_input_scale = weights[
- f"down_proj_input_scale"] if f"down_proj_input_scale" in weights else None
+ w1_input_scale = weights[
+ "gate_up_proj_input_scale"] if "gate_up_proj_input_scale" in weights else None
+ w3_input_scale = weights[
+ "gate_up_proj_input_scale"] if "gate_up_proj_input_scale" in weights else None
+ w2_input_scale = weights[
+ "down_proj_input_scale"] if "down_proj_input_scale" in weights else None🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/modules/fused_moe/quantization.py around lines 565 -
570, Replace placeholder f-strings like f"gate_up_proj_input_scale" and
f"down_proj_input_scale" with plain string literals; in
load_activation_scales_fp8_qdq() stop storing temporary tensors as long-lived
module attributes (module.tmp_fc31_input_scale/module.tmp_fc2_input_scale) or
ensure they are removed immediately after they are copied into
dst_fc31_input_scale/dst_fc2_input_scale—e.g., copy the values then
delattr(module, "tmp_fc31_input_scale") / delattr(module, "tmp_fc2_input_scale")
or set them to None so partial-loading doesn't leak buffers; update
process_weights_after_loading() behavior to be consistent but do not rely on it
for cleanup when partial-loading is enabled; finally add the required NVIDIA
copyright header at the top of the file.
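A hedged sketch of the cleanup the prompt describes: copy the temporary input scales into their destinations, then drop the module attributes so partial loading cannot leak them. Attribute names mirror the snippet under review, but the helper itself (and the use of a max-reduction for aggregation) is illustrative:

```python
def finalize_input_scales(module, dst_fc31_input_scale, dst_fc2_input_scale):
    """Illustrative: fold the tmp_* input scales into their destinations, then drop them."""
    if getattr(module, "tmp_fc31_input_scale", None) is not None:
        dst_fc31_input_scale.copy_(module.tmp_fc31_input_scale.max())
        module.tmp_fc31_input_scale = None  # or delattr(module, "tmp_fc31_input_scale")
    if getattr(module, "tmp_fc2_input_scale", None) is not None:
        dst_fc2_input_scale.copy_(module.tmp_fc2_input_scale.max())
        module.tmp_fc2_input_scale = None  # or delattr(module, "tmp_fc2_input_scale")
```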
Diff context under review (the unconditional weights[...] lookups and the unguarded load call are the pre-change lines; the None-guarded variants are the new ones):
    load_activation_scales_fp8_qdq(module, weights)

    # Step2: Load weight scales
    if not hasattr(module, 'tmp_w3_w1_weight_scale'):
        module.tmp_w3_w1_weight_scale = torch.empty(
            (module.expert_size_per_partition, 2),
            dtype=torch.float32,
            device=module.fc31_dequant.device)
    if not hasattr(module, 'tmp_w2_weight_scale'):
        module.tmp_w2_weight_scale = torch.empty(
            module.expert_size_per_partition,
            dtype=torch.float32,
            device=module.fc2_dequant.device)
    tmp_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale
    tmp_w2_weight_scale = module.tmp_w2_weight_scale

    for local_slot_id, expert_id in enumerate(
            module.initial_local_expert_ids):
        if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
            w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
            w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
            w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
            w1_weight_scale = weights[
                f"{expert_id}.w1.weight_scale"] if f"{expert_id}.w1.weight_scale" in weights else None
            w3_weight_scale = weights[
                f"{expert_id}.w3.weight_scale"] if f"{expert_id}.w3.weight_scale" in weights else None
            w2_weight_scale = weights[
                f"{expert_id}.w2.weight_scale"] if f"{expert_id}.w2.weight_scale" in weights else None
        elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
            w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
            w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
            w2_weight_scale = weights[f"down_proj_weight_scale"]
            w1_weight_scale = weights[
                f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None
            w3_weight_scale = weights[
                f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None
            w2_weight_scale = weights[
                f"down_proj_weight_scale"] if f"down_proj_weight_scale" in weights else None
        else:
            raise NotImplementedError(
                f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
            )

        expert_idx = local_slot_id

        self.load_expert_w3_w1_weight_scale_fp8_qdq(
            w1_weight_scale, w3_weight_scale,
            tmp_w3_w1_weight_scale[expert_idx])

        if w1_weight_scale is not None or w3_weight_scale is not None:
            self.load_expert_w3_w1_weight_scale_fp8_qdq(
                w1_weight_scale, w3_weight_scale,
                tmp_w3_w1_weight_scale[expert_idx])

        if w2_weight_scale is not None:
            self.load_expert_w2_weight_scale_fp8(
                w2_weight_scale, tmp_w2_weight_scale[expert_idx])
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM):
- Get the full file context around lines 707-784: head -n 800 tensorrt_llm/_torch/modules/fused_moe/quantization.py | tail -n 150 | cat -n (output length: 8548)
- Check what comes after line 784 to see the max() operation: sed -n '700,850p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | cat -n (output length: 8740)
- Search for max() operations on tmp_* variables to see post-load finalization: rg "tmp_w3_w1_weight_scale|tmp_w2_weight_scale" tensorrt_llm/_torch/modules/fused_moe/quantization.py -A 5 -B 2 (output length: 2692)
- Check for initialization of these tmp buffers to -inf or 0: rg "tmp_w3_w1_weight_scale.*=.*torch\.(empty|full|zeros|ones)" tensorrt_llm/_torch/modules/fused_moe/quantization.py -A 3 (output length: 262)
- Find and examine load_activation_scales_fp8_qdq: rg -B 5 "def load_activation_scales_fp8_qdq" tensorrt_llm/_torch/modules/fused_moe/quantization.py -A 30 (output length: 2031)
- Check for initialization or guard logic before calling .max(): rg "max\(\)" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 3 -A 1 (output length: 2482)
- Check for assertions or validation before the max() calls: rg -B 5 "\.max\(\)" tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -50 (output length: 2519)
- Check whether all slots are guaranteed to be written: rg "def load_activation_scales_fp8_qdq" tensorrt_llm/_torch/modules/fused_moe/quantization.py -A 60 | head -80 (output length: 3242)
- Verify whether module.initial_local_expert_ids covers all experts: rg "initial_local_expert_ids" tensorrt_llm/_torch/modules/fused_moe/quantization.py -B 2 -A 2 | head -40 (output length: 1652)
- Check for validation that all slots are written before .max(): rg -B 10 "tmp_w3_w1_weight_scale.max\(\|tmp_w2_weight_scale.max\(\|tmp_fc31_input_scale.max\(\|tmp_fc2_input_scale.max\(\)" tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -60 (output length: 45)
- Verify the complete load_activation_scales_fp8_qdq for safeguards: sed -n '/^def load_activation_scales_fp8_qdq/,/^def [^_]/p' tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -60 (output length: 2303)
Initialize tmp scale buffers and add validation before .max() operations to prevent undefined behavior from uninitialized slots.
When scales are missing for an expert, tmp buffer slots remain uninitialized from torch.empty(). The unconditional .max() calls on these buffers in process_weights_after_loading() (lines 771, 775, and similar in other methods) operate on undefined memory. This affects both weight scales (tmp_w3_w1_weight_scale, tmp_w2_weight_scale) and input scales (tmp_fc31_input_scale, tmp_fc2_input_scale).
Initialize tmp buffers to 0 or -inf (as appropriate for the aggregation operation), and track whether any scales were actually loaded for each slot. Skip the aggregation or apply a fallback if all scales for a slot remain missing.
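As a concrete illustration of that recommendation, here is a minimal sketch. It reuses the attribute names from the diff above, but the `*_loaded` masks and both helper functions are hypothetical, not code that exists in this PR:

```python
# Sketch only: deterministic initialization plus a per-slot "loaded" mask.
import torch


def init_tmp_weight_scales(module) -> None:
    if not hasattr(module, "tmp_w3_w1_weight_scale"):
        # Zeros (instead of torch.empty) so an untouched slot cannot feed garbage into .max().
        module.tmp_w3_w1_weight_scale = torch.zeros(
            (module.expert_size_per_partition, 2),
            dtype=torch.float32,
            device=module.fc31_dequant.device)
        module.tmp_w3_w1_weight_scale_loaded = torch.zeros(
            module.expert_size_per_partition, dtype=torch.bool)
    if not hasattr(module, "tmp_w2_weight_scale"):
        module.tmp_w2_weight_scale = torch.zeros(
            module.expert_size_per_partition,
            dtype=torch.float32,
            device=module.fc2_dequant.device)
        module.tmp_w2_weight_scale_loaded = torch.zeros(
            module.expert_size_per_partition, dtype=torch.bool)


def finalize_w3_w1_scale(module) -> torch.Tensor:
    # Refuse to aggregate until every slot actually received a scale during (partial) loading.
    loaded = module.tmp_w3_w1_weight_scale_loaded
    assert bool(loaded.all()), "some experts never received w3/w1 weight scales"
    return module.tmp_w3_w1_weight_scale.max()
```

The loader side would then set `module.tmp_w3_w1_weight_scale_loaded[expert_idx] = True` (and the analogous flag for w2 and the input scales) whenever a per-expert scale is actually written.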
🧰 Tools
🪛 Ruff (0.14.10)
734-734: f-string without any placeholders — remove extraneous `f` prefix (F541)
736-736: f-string without any placeholders — remove extraneous `f` prefix (F541)
738-738: f-string without any placeholders — remove extraneous `f` prefix (F541)
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/modules/fused_moe/quantization.py around lines 707 -
754, The tmp scale buffers (module.tmp_w3_w1_weight_scale,
module.tmp_w2_weight_scale, module.tmp_fc31_input_scale,
module.tmp_fc2_input_scale) must be initialized to safe values (e.g., zeros or
-inf for max-reduction) instead of torch.empty(), and we must record per-slot
whether a scale was loaded so downstream max()/aggregation in
process_weights_after_loading() is not called on uninitialized memory; update
the allocation in load_activation_scales_fp8_qdq / the block creating tmp_* to
use a deterministic fill (torch.zeros or torch.full(..., -float('inf')) as
appropriate) and add a boolean mask or loaded_count per expert that
load_expert_w3_w1_weight_scale_fp8_qdq and load_expert_w2_weight_scale_fp8 set
when they actually load a scale, then in process_weights_after_loading() check
that mask/count for each slot and skip aggregation or apply a safe fallback when
no scales were loaded for that slot.
Excerpt from `tensorrt_llm/llmapi/rlhf_utils.py`:

```python
moe_load_balancer.finalize_model()
logger.info("moe_load_balancer finalize model done")
self.engine.reset_prefix_cache()
delattr(self.engine.model_engine.model, "first_pre_reload_weights")
```
Potential AttributeError if first_pre_reload_weights was never set.
If update_weights is called with ipc_handles=None (finalize path) without a prior call that set first_pre_reload_weights, delattr will raise AttributeError. Consider using hasattr check or try-except:
```diff
-delattr(self.engine.model_engine.model, "first_pre_reload_weights")
+if hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
+    delattr(self.engine.model_engine.model, "first_pre_reload_weights")
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
    delattr(self.engine.model_engine.model, "first_pre_reload_weights")
```
🤖 Prompt for AI Agents
In @tensorrt_llm/llmapi/rlhf_utils.py at line 107, The code unconditionally
calls delattr(self.engine.model_engine.model, "first_pre_reload_weights") which
can raise AttributeError if that attribute was never set (e.g., update_weights
called with ipc_handles=None on the finalize path); update the logic in the
update_weights method to either check hasattr(self.engine.model_engine.model,
"first_pre_reload_weights") before calling delattr, or wrap the delattr call in
a try/except AttributeError and ignore the exception, referencing the attribute
name "first_pre_reload_weights" and the object self.engine.model_engine.model to
locate the change.
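For completeness, the try/except flavor mentioned above could look roughly like this; `clear_first_pre_reload_flag` is a hypothetical helper, only the attribute name comes from the code under review:

```python
import contextlib


def clear_first_pre_reload_flag(model) -> None:
    # Equivalent to the hasattr guard: silently ignore the attribute if it was never set.
    with contextlib.suppress(AttributeError):
        delattr(model, "first_pre_reload_weights")
```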
Diff hunk from `tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py` (paired entries show the previous lines followed by the revised ones):

```diff
 class RefHFModelWithIPCHandles(RefHFModel):
     def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4):
         self.device_id = device_id
         config = AutoConfig.from_pretrained(model_dir)
         config.num_hidden_layers = num_hidden_layers
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.bfloat16
-        ).to("cuda")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.cuda_device = torch.cuda.current_device()
+            model_dir, config=config, torch_dtype=torch.bfloat16, attn_implementation="eager"
+        ).to(f"cuda:{device_id}")
         self.all_weights = {}
-        self.device_uuid = [HFModel.get_device_uuid(i) for i in range(torch.cuda.device_count())]
+        self.device_uuid = [get_device_uuid(i) for i in range(torch.cuda.device_count())]
         self._replicate_weights()

     @staticmethod
     def get_device_uuid(cuda_device: int):
         from tensorrt_llm._torch.utils import get_device_uuid

         return get_device_uuid(cuda_device)

     def _replicate_weights(self):
         model_weights = []
         for n, p in self.model.named_parameters():
             model_weights.append((n, p.detach().clone()))

-        self.all_weights[self.cuda_device] = model_weights
+        self.all_weights[self.device_id] = model_weights
         for i in range(torch.cuda.device_count()):
-            if i != self.cuda_device:
+            if i != self.device_id:
                 cur_weights = []
-                for n, p in self.all_weights[self.cuda_device]:
+                for n, p in self.all_weights[self.device_id]:
                     cur_weights.append((n, p.to("cuda:" + str(i))))
                 self.all_weights[i] = cur_weights
```
|
Avoid replicating all model weights onto every CUDA device in RefHFModelWithIPCHandles (likely OOM on CI).
_replicate_weights() clones the weights and then copies them to every device reported by torch.cuda.device_count(). Even with num_hidden_layers=1, this can be extremely large for the bigger Qwen variants, and it scales with the GPU count.
Given your call sites use get_weight_ipc_handles([0], ...), you can store weights only on device_id and (optionally) materialize other devices on demand.
Proposed patch (lazy replication only if requested)

```diff
 class RefHFModelWithIPCHandles(RefHFModel):
     def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4):
         self.device_id = device_id
@@
         self.all_weights = {}
         self.device_uuid = [get_device_uuid(i) for i in range(torch.cuda.device_count())]
         self._replicate_weights()

     def _replicate_weights(self):
         model_weights = []
         for n, p in self.model.named_parameters():
             model_weights.append((n, p.detach().clone()))
-
-        self.all_weights[self.device_id] = model_weights
-        for i in range(torch.cuda.device_count()):
-            if i != self.device_id:
-                cur_weights = []
-                for n, p in self.all_weights[self.device_id]:
-                    cur_weights.append((n, p.to("cuda:" + str(i))))
-                self.all_weights[i] = cur_weights
+        self.all_weights[self.device_id] = model_weights
+
+    def _ensure_weights_on_device(self, device: int) -> None:
+        if device in self.all_weights:
+            return
+        if device == self.device_id:
+            return
+        src = self.all_weights[self.device_id]
+        self.all_weights[device] = [(n, t.to(f"cuda:{device}")) for n, t in src]
```

```diff
     def get_weight_ipc_handles(
         self,
         device_ids: Optional[List[int]] = None,
         weight_filter: Optional[Callable[[str], bool]] = None,
     ):
@@
-        for device in device_list:
+        for device in device_list:
+            self._ensure_weights_on_device(device)
             all_handles = []
             for item in self.all_weights[device]:
                 name, p = item
```

🤖 Prompt for AI Agents
In @tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
around lines 20 - 44, Refactor RefHFModelWithIPCHandles._replicate_weights to
avoid copying model weights to every CUDA device; instead only store the cloned
weights for self.device_id in self.all_weights and do not iterate over
torch.cuda.device_count(); add (or update) a lazy materialization path that
converts weights to another device only when requested (e.g., from
get_weight_ipc_handles or a new method like materialize_weights_on(device_idx)),
so other devices are created on-demand and prevent OOM on CI.
Excerpt from `tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py`:

```python
def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
    if os.path.exists(dst_folder):
        shutil.rmtree(dst_folder)
    os.makedirs(dst_folder)

    for root, dirs, files in os.walk(src_folder):
        rel_path = os.path.relpath(root, src_folder)
        dest_dir = os.path.join(dst_folder, rel_path)

        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)

        for file in files:
            src_path = os.path.join(root, file)
            dest_path = os.path.join(dest_dir, file)
            if "safetensor" in file:
                continue

            if file == "config.json":
                with open(src_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                config["num_hidden_layers"] = num_hidden_layers
                with open(dest_path, "w", encoding="utf-8") as f:
                    json.dump(config, f, indent=2, ensure_ascii=False)
            else:
                shutil.copy2(src_path, dest_path)
```
|
🧩 Analysis chain
🏁 Scripts executed (CodeRabbit verification against NVIDIA/TensorRT-LLM):

```shell
# First, locate and read the file to see its actual state
fd "test_llm_update_weights.py" -x cat -n {}

# Verify Ruff B007 definition and search for best practices
web_search "Ruff B007 flake8-bugbear unused loop variable"
```
Rename unused loop variable and tighten directory creation.
Rename dirs to _dirs on line 126 to address Ruff B007. Replace lines 130-131 with os.makedirs(dest_dir, exist_ok=True) to eliminate the explicit existence check.
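A sketch of just those two fixes; the helper below is hypothetical and omits the safetensor/config.json handling of the original function:

```python
import os
import shutil


def copy_tree_skeleton(src_folder: str, dst_folder: str) -> None:
    # Shows only the reviewed pattern, not the full test utility.
    for root, _dirs, files in os.walk(src_folder):  # _dirs is unused, renamed per Ruff B007
        rel_path = os.path.relpath(root, src_folder)
        dest_dir = os.path.join(dst_folder, rel_path)
        os.makedirs(dest_dir, exist_ok=True)  # idempotent; replaces the explicit exists() check
        for file in files:
            shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file))
```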
🧰 Tools
🪛 Ruff (0.14.10)
126-126: Loop control variable dirs not used within loop body
Rename unused dirs to _dirs
(B007)
🤖 Prompt for AI Agents
In @tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
around lines 121 - 147, In process_and_copy_folder, rename the unused loop
variable dirs to _dirs to satisfy Ruff B007, and replace the explicit existence
check plus os.makedirs for dest_dir with a single os.makedirs(dest_dir,
exist_ok=True) call so directory creation is idempotent; update the for root,
_dirs, files in os.walk(...) and change the dest_dir handling to use
os.makedirs(..., exist_ok=True) before writing or copying files.
Excerpt from `tests/unittest/utils/torch_ref.py`:

```python
class RefHFModel:

    def __init__(self,
                 model_dir: str,
                 device_id: int = 0,
                 additional_model_kargs: Optional[Dict[str, Any]] = None):
        self.device_id = device_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, **(additional_model_kargs or {})).to(f"cuda:{device_id}")

    def generate_batch_with_padding(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        responses: List[List[int]],
        prompt_max_len: int = 1024,
        micro_batch_size: int = 16,
        return_logits: bool = False,
    ):
        """
        Synchronous inference on a batch with micro-batching.
        Directly extracts response logprobs to save memory.

        Args:
            input_ids: [batch_size, seq_len]
            attention_mask: [batch_size, seq_len]
            position_ids: [batch_size, seq_len]
            responses: List of response token IDs for each sample
            prompt_max_len: Maximum prompt length (default 1024)
            micro_batch_size: Size of each micro batch to avoid OOM
            return_logits: Whether to return logits. If True, return logits, otherwise return logprobs

        Returns:
            List of logits or logprobs tensors, one per sample [response_len]
        """
        # Move tensors to the correct device
        input_ids = input_ids.to(f"cuda:{self.device_id}")
        attention_mask = attention_mask.to(f"cuda:{self.device_id}")
        position_ids = position_ids.to(f"cuda:{self.device_id}")

        batch_size = input_ids.shape[0]
        num_micro_batches = (batch_size + micro_batch_size -
                             1) // micro_batch_size

        ref_results = []

        with torch.no_grad():
            for micro_idx in range(num_micro_batches):
                start_idx = micro_idx * micro_batch_size
                end_idx = min((micro_idx + 1) * micro_batch_size, batch_size)

                # Extract micro batch
                micro_input_ids = input_ids[start_idx:end_idx]
                micro_attention_mask = attention_mask[start_idx:end_idx]
                micro_position_ids = position_ids[start_idx:end_idx]

                # Forward pass
                outputs = self.model(
                    input_ids=micro_input_ids,
                    attention_mask=micro_attention_mask,
                    position_ids=micro_position_ids,
                )

                # Extract response logprobs for each sample in this micro batch
                for i in range(outputs.logits.shape[0]):
                    sample_idx = start_idx + i
                    response = responses[sample_idx]
                    response_len = len(response)

                    # Extract logits for predicting response tokens
                    # For predicting response[j], we need logits at position prompt_max_len-1+j
                    response_logits = outputs.logits[i, prompt_max_len -
                                                     1:prompt_max_len - 1 +
                                                     response_len, :]
                    if return_logits:
                        ref_results.append(response_logits)
                    else:
                        # Convert to logprobs
                        response_logprobs = torch.log_softmax(response_logits,
                                                              dim=-1)

                        # Extract logprobs for the actual generated tokens
                        response_tensor = torch.tensor(
                            response,
                            dtype=torch.long,
                            device=response_logprobs.device)
                        ref_logprob_for_tokens = torch.gather(
                            response_logprobs,
                            dim=-1,
                            index=response_tensor.unsqueeze(-1)).squeeze(-1)

                        ref_results.append(ref_logprob_for_tokens)

                # Free memory immediately after processing each micro batch
                del outputs
                torch.cuda.empty_cache()

        return ref_results
```
Fix OOM risk: return_logits=True currently retains the full outputs.logits backing storage.
response_logits = outputs.logits[i, ...] is a view; appending it means del outputs + empty_cache() won’t actually release the large logits tensor. For the update-weights tests (which request logits), this can balloon memory.
Proposed patch

```diff
-                    response_logits = outputs.logits[i, prompt_max_len -
-                                                     1:prompt_max_len - 1 +
-                                                     response_len, :]
+                    response_logits = outputs.logits[
+                        i,
+                        prompt_max_len - 1:prompt_max_len - 1 + response_len,
+                        :,
+                    ]
                     if return_logits:
-                        ref_results.append(response_logits)
+                        ref_results.append(response_logits.detach().clone())
                     else:
                         # Convert to logprobs
                         response_logprobs = torch.log_softmax(response_logits,
                                                               dim=-1)
```

Excerpt from `tests/unittest/utils/torch_ref.py` (`RefHFModel.pad_data`):

```python
    @staticmethod
    def pad_data(
        original_prompts: List[List[int]],
        generated_token_ids_list: List[List[int]],
        prompt_max_len: int = 1024,
        response_max_len: int = 1024,
        pad_token_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Pad the data to the maximum length.

        Structure:
            [left_pad | actual_prompt | actual_response | right_pad]
            |<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->|

        Args:
            original_prompts: List of prompt token IDs, len = batch_size
            generated_token_ids_list: List of response token IDs, len = batch_size
            prompt_max_len: Maximum length for prompt section (default 1024)
            response_max_len: Maximum length for response section (default 1024)
            pad_token_id: Token ID for padding (default 0)

        Returns:
            input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
            attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len]
            position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
        """
        batch_size = len(original_prompts)
        total_len = prompt_max_len + response_max_len

        for i, (prompt, response) in enumerate(
                zip(original_prompts, generated_token_ids_list)):
            assert len(prompt) <= prompt_max_len, (
                f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}"
            )
            assert len(response) <= response_max_len, (
                f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}"
            )

        # Build batch tensors [batch_size, total_len]
        batch_input_ids = torch.full((batch_size, total_len),
                                     pad_token_id,
                                     dtype=torch.long,
                                     device="cuda")
        batch_attention_mask = torch.zeros((batch_size, total_len),
                                           dtype=torch.long,
                                           device="cuda")
        batch_position_ids = torch.zeros((batch_size, total_len),
                                         dtype=torch.long,
                                         device="cuda")

        response_lens = []

        for i in range(batch_size):
            prompt_tokens = original_prompts[i]
            response_tokens = generated_token_ids_list[i]

            prompt_len = len(prompt_tokens)
            response_len = len(response_tokens)
            response_lens.append(response_len)

            left_pad_len = prompt_max_len - prompt_len

            # Fill input_ids: [left_pad | prompt | response | right_pad]
            prompt_start = left_pad_len
            prompt_end = prompt_max_len
            response_start = prompt_max_len
            response_end = prompt_max_len + response_len

            batch_input_ids[i, prompt_start:prompt_end] = torch.tensor(
                prompt_tokens, dtype=torch.long, device="cuda")
            batch_input_ids[i, response_start:response_end] = torch.tensor(
                response_tokens, dtype=torch.long, device="cuda")

            # Fill attention_mask: 1 for actual tokens, 0 for padding
            batch_attention_mask[i, prompt_start:response_end] = 1

            # Fill position_ids: sequential for actual tokens
            actual_seq_len = prompt_len + response_len
            batch_position_ids[i, prompt_start:response_end] = torch.arange(
                actual_seq_len, dtype=torch.long, device="cuda")
            # Right padding keeps the last position value
            if response_len < response_max_len:
                batch_position_ids[i, response_end:] = actual_seq_len - 1

        return batch_input_ids, batch_attention_mask, batch_position_ids
```
🧩 Analysis chain
🏁 Scripts executed (CodeRabbit verification against NVIDIA/TensorRT-LLM):

```shell
# First, let's find and examine the file
find . -name "torch_ref.py" -type f | head -5

# Read the file around the specified lines
cat -n tests/unittest/utils/torch_ref.py | sed -n '1343,1427p'

head -20 tests/unittest/utils/torch_ref.py
```
Remove dead code and ensure Python 3.8 compatibility.

- `response_lens` is populated at line 1401 but never used or returned; remove it.
- Use an explicit length check instead of `zip(..., strict=...)` for Python 3.8+ compatibility.
Proposed patch

```diff
-        for i, (prompt, response) in enumerate(
-                zip(original_prompts, generated_token_ids_list)):
+        assert len(original_prompts) == len(generated_token_ids_list)
+        for i, (prompt, response) in enumerate(
+                zip(original_prompts, generated_token_ids_list)):
             assert len(prompt) <= prompt_max_len, (
                 f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}"
             )
             assert len(response) <= response_max_len, (
                 f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}"
             )
@@
-        response_lens = []
-
         for i in range(batch_size):
@@
-            response_lens.append(response_len)
-
             left_pad_len = prompt_max_len - prompt_len
```

🧰 Tools
🪛 Ruff (0.14.10)
1373-1373: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
In @tests/unittest/utils/torch_ref.py around lines 1343 - 1427, The pad_data
static method contains dead code and a Python-version-incompatible pattern:
remove the unused response_lens list (created at response_lens = [] and appended
to inside the loop) and eliminate any reliance on zip(..., strict=...) by
keeping the current explicit pairing approach (use range(len(original_prompts))
or manually check lengths) so pad_data (the function) only iterates with indices
to validate and populate batch tensors; ensure no references to response_lens
remain and that input validation compares lengths of original_prompts and
generated_token_ids_list explicitly before building tensors.
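For orientation, a tiny usage sketch of `pad_data` with shrunken lengths (values chosen for illustration; a CUDA device is required since the helper allocates tensors on `cuda`, and the import path may need adjusting for your test environment):

```python
from utils.torch_ref import RefHFModel  # path as laid out under tests/unittest; adjust if needed

prompts = [[101, 7592], [101, 2023, 2003]]
responses = [[1996, 3899], [2307]]

input_ids, attention_mask, position_ids = RefHFModel.pad_data(
    prompts, responses, prompt_max_len=4, response_max_len=4, pad_token_id=0)

# Sample 0 is laid out as [pad, pad, prompt, prompt, response, response, pad, pad]:
# input_ids[0]      -> [0, 0, 101, 7592, 1996, 3899, 0, 0]
# attention_mask[0] -> [0, 0,   1,    1,    1,    1, 0, 0]
print(input_ids.shape)  # torch.Size([2, 8])
```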
```python
    def pre_reload_weights(self):
        """
        Pre reload weights - delegated to backend
```
Remove the empty line.
Refactored in this commit.
```python
                     allow_partial_loading: bool = False):
        raise NotImplementedError

    def process_weights_after_loading(self):
```
Could you add some docstrings to these new APIs? These can be helpful for other developers. It is better to have these three APIs documented as they look related.
Added in this new commit.
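For readers following along, the kind of docstrings being asked for might look like the sketch below; the wording is illustrative, not the exact text that landed in the commit:

```python
class WeightLoadingMixin:
    """Illustrative only: docstring sketches for the three related APIs."""

    def pre_reload_weights(self):
        """Restore parameters to their as-created shapes before a weight reload.

        Reverts tensors that process_weights_after_loading() replaced or
        reshaped, using the saved meta-tensor shape information, so a
        subsequent load_weights() call sees the original layout.
        """

    def load_weights(self, weights, allow_partial_loading: bool = False):
        """Copy the given weights into the module.

        With allow_partial_loading=True, only the tensors present in
        `weights` are updated and post-processing is deferred to
        process_weights_after_loading().
        """

    def process_weights_after_loading(self):
        """Finalize loaded weights: fuse/quantize tensors and recompute
        derived scales and metadata once all (partial) loads are done."""
```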
HuiGao-NV
left a comment
LGTM
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: Shuyi Xiong <[email protected]>
Signed-off-by: shuyixiong <[email protected]>
Signed-off-by: shuyixiong <[email protected]>
e59b2db to 7003908 — Compare
|
/bot run --disable-fail-fast |
|
PR_Github #32275 [ run ] triggered by Bot. Commit: |
|
PR_Github #32275 [ run ] completed with state
|
Signed-off-by: shuyixiong <[email protected]>
|
/bot run --disable-fail-fast |
|
PR_Github #32315 [ run ] triggered by Bot. Commit: |
|
PR_Github #32315 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #32383 [ run ] triggered by Bot. Commit: |
|
PR_Github #32383 [ run ] completed with state
|
Summary by CodeRabbit
New Features
Refactor
- `load_weights` function: using the `allow_partial_loading` parameter to defer quantization when needed; currently only used for the update_weights workflow.
- A `pre_reload_weights` step is introduced to revert tensors that are modified in `post_load_weights` back to their original creation shape. The original shape information is preserved using meta tensors. NOTE that `pre_reload_weights` is now incompatible with CUDA Graph. A warning will be issued when tensor reconstruction is required.

Tests
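Taken together, the Refactor notes above imply roughly this call order for a partial update; this is a sketch only, and the exact engine-side entry points may differ:

```python
def partial_update(module, weight_shards):
    """Hypothetical driver illustrating the intended weight-reload lifecycle."""
    module.pre_reload_weights()  # revert tensors to their creation shapes (not CUDA-Graph safe)
    for shard in weight_shards:  # each shard: dict of tensor name -> tensor
        module.load_weights(shard, allow_partial_loading=True)
    module.process_weights_after_loading()  # requantize and rebuild metadata once
```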
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.

Details

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`

Kill all running builds associated with pull request.

skip

`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.